import numpy as np
import multiprocessing as mp
import pandas as pd
import time

# Sampling script: Simulation_MultiHeadAttention_sqrtn_lowrank.py
# Low‐rank attention: inner product dimension is n_H instead of n.

# Simulation parameters will be set via params dict
# n: Full network width (n = H * n_H)
# n_H: Per-head width (so that n = H * n_H)
# s: Spatial dimension
# H: Number of heads
# num_runs: Number of simulation runs
# num_processes: Number of parallel processes
# seed: Random seed
# mc_runs: Monte Carlo runs for theoretical density
# C: Clipping threshold for psi

# Globals (initialized in init_globals)
n = n_H = None                    # full network width and per-head width
s = H = None                      # spatial dimension and number of heads
num_runs = num_processes = None   # empirical run count and pool size
seed = mc_runs = C = None         # random seed, Monte Carlo run count, clipping threshold

# Weight sampling W ~ N(0,1/scale_dim)
def scaled_weights(shape, scale_dim):
    """Draw a Gaussian array whose entries are distributed as N(0, 1/scale_dim).

    Args:
        shape: Iterable of ints giving the dimensions of the returned array.
        scale_dim: Variance denominator; every standard-normal draw is
            divided by sqrt(scale_dim).

    Returns:
        Array of the requested shape, drawn from NumPy's global random state.
    """
    standard_normal_draws = np.random.randn(*shape)
    std_divisor = np.sqrt(scale_dim)
    return standard_normal_draws / std_divisor

# Initialize globals in each process or main
def init_globals(params):
    """Copy simulation settings from ``params`` into module globals and seed NumPy.

    Keys absent from ``params`` leave the corresponding global untouched.
    ``n_H`` is taken from ``params`` when given; otherwise it is derived as
    ``n // H`` once both values are known. The derivation is done lazily so
    that supplying ``n_H`` without ``n``/``H`` does not raise.

    This function is also used as the ``multiprocessing.Pool`` initializer.
    Seeding every worker with the same ``seed`` would make all workers
    replay the same random stream and produce duplicated samples, so inside
    a child process the worker's identity is mixed into the seed. In the
    main process the seed is applied unchanged.

    Args:
        params: Mapping that may contain 'n', 'n_H', 's', 'H', 'num_runs',
            'num_processes', 'seed', 'mc_runs' and 'C'.
    """
    global n_H
    module_state = globals()
    for key in ('n', 's', 'H', 'num_runs', 'num_processes', 'seed',
                'mc_runs', 'C'):
        if key in params:
            module_state[key] = params[key]

    if 'n_H' in params:
        n_H = params['n_H']
    elif n is not None and H is not None:
        n_H = n // H  # per-head dimension

    # NOTE: `_identity` is a CPython multiprocessing implementation detail
    # (empty tuple in the main process, non-empty in child processes). If it
    # is unavailable we fall back to the plain seed.
    worker_identity = getattr(mp.current_process(), '_identity', ())
    if seed is not None and worker_identity:
        np.random.seed([seed, *worker_identity])
    else:
        np.random.seed(seed)

# Softmax function
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating. This leaves
    the mathematical result unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (which would give ``nan`` after normalisation) for large
    logits, and from underflowing every entry to zero for very negative ones.

    Args:
        x: Array-like of logits.
        axis: Axis along which the outputs sum to one.

    Returns:
        Array with the same shape as ``x``.
    """
    x = np.asarray(x)
    if x.size:
        # Guarded because max() has no identity on zero-size arrays.
        x = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(x)
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

# Single empirical run: output matrix Y shape (s,n)
def single_run(_):
    """Simulate one random multi-head attention pass and return its (s, n) output.

    Reads the module globals n, n_H, s, H and C, and draws all randomness
    from NumPy's global random state. The argument is ignored; it exists so
    the function can be mapped over a range of run indices.
    """
    # Initial vector h ~ N(0, I_n).
    h_init = np.random.randn(n)
    # One (n, n) matrix W^i per position, entries N(0, 1/n); row i of the
    # product is h^i = W^i h.
    position_weights = scaled_weights((s, n, n), n)
    preactivations = position_weights @ h_init  # (s, n)
    # Elementwise clipping activation.
    tokens = np.clip(preactivations, -C, C)  # (s, n)

    output = np.zeros((s, n))
    for _head in range(H):
        # Draw order (query, key, value, output) is kept fixed so the random
        # stream is consumed in the same sequence on every run.
        query_w = scaled_weights((n, n_H), n)    # (n, n_H)
        key_w = scaled_weights((n, n_H), n)      # (n, n_H)
        value_w = scaled_weights((n, n_H), n)    # (n, n_H)
        out_w = scaled_weights((n_H, n), n_H)    # (n_H, n)

        queries = tokens @ query_w               # (s, n_H)
        keys = tokens @ key_w                    # (s, n_H)
        projected_values = (tokens @ value_w) @ out_w  # (s, n)

        scores = queries.dot(keys.T) / np.sqrt(n_H)    # (s, s)
        attention = softmax(scores, axis=1)
        output += attention @ projected_values
    return output

# Simulate empirical: returns array (num_runs, s, n)
def simulate_empirical(params):
    """Run ``params['num_runs']`` independent ``single_run`` draws in a process pool.

    Results are streamed into one preallocated array instead of being
    collected in a list and then copied by ``np.stack``, which roughly
    halves peak memory for large ``num_runs``.

    Args:
        params: Settings dict; must contain 'num_runs' and 'num_processes'
            and is forwarded to ``init_globals`` in the parent and in every
            worker.

    Returns:
        Array of shape (num_runs, s, n) holding the run outputs in run
        order. With zero runs an empty (0, s, n) array is returned.
    """
    init_globals(params)
    total_runs = params['num_runs']
    worker_count = params['num_processes']
    # imap's default chunksize of 1 would pay inter-process overhead on every
    # run, so hand each worker roughly four batches instead.
    batches_wanted = 4 * (worker_count or mp.cpu_count())
    chunksize = max(1, -(-total_runs // batches_wanted))

    samples = None
    with mp.Pool(processes=worker_count, initializer=init_globals,
                 initargs=(params,)) as pool:
        run_outputs = pool.imap(single_run, range(total_runs), chunksize)
        for run_idx, run_output in enumerate(run_outputs):
            if samples is None:
                # Allocate once the per-run shape and dtype are known.
                run_output = np.asarray(run_output)
                samples = np.empty((total_runs,) + run_output.shape,
                                   dtype=run_output.dtype)
            samples[run_idx] = run_output
    if samples is None:
        samples = np.empty((0, s, n))
    return samples

# Theoretical Monte Carlo for n^{-1/2}: returns array (mc_runs, s)
def simulate_theoretical(params):
    """Draw Monte Carlo samples of the reference output.

    For every run and head, an (s, s) matrix of standard-normal logits is
    turned into row-wise softmax weights and applied to a standard-normal
    vector of length s; the per-head results are summed over heads.

    The computation is vectorised over runs and heads rather than looping in
    Python. The random draws are made in the same order and with the same
    shapes as before; only floating-point summation order may differ.
    The full (mc_runs, H, s, s) weight tensor is materialised at once.

    Args:
        params: Settings dict; must contain 'mc_runs' and is forwarded to
            ``init_globals``.

    Returns:
        Array of shape (mc_runs, s).
    """
    init_globals(params)
    mc_total = params['mc_runs']
    logits = np.random.randn(mc_total, H, s, s)
    head_values = np.random.randn(mc_total, H, s)
    # Normalise over the last axis, i.e. each row of each (s, s) logit matrix.
    attention = softmax(logits, axis=-1)
    # y[r, i] = sum over heads h and positions j of attention[r, h, i, j] * Z[r, h, j]
    return np.einsum('rhij,rhj->ri', attention, head_values)

if __name__ == '__main__':
    # Sweep over the number of heads at a fixed per-head width, so the full
    # width n grows with H.
    n_H = 64
    base = {
        'n_H': n_H,
        's': 4,
        'num_runs': 300000,
        'num_processes': 18,
        'seed': 0,
        'mc_runs': 300000,
        'C': 100,
    }
    H_vals = [4 ** exponent for exponent in range(3)]

    records = []

    start = time.time()

    for head_count in H_vals:
        params = {**base, 'H': head_count, 'n': n_H * head_count}

        # Keep only the first coordinate at the first position of each
        # empirical sample, and the first position of each theoretical one.
        y_emp = simulate_empirical(params)[:, 0, 0]
        y_theo = simulate_theoretical(params)[:, 0]

        records.extend(
            {'param': head_count, 'y_emp': emp_value, 'y_theo': theo_value}
            for emp_value, theo_value in zip(y_emp, y_theo)
        )

    # Save to CSV
    results = pd.DataFrame(records)
    results.to_csv('data_vary_n_and_H.csv', index=False)

    print(f"Sampling done in {time.time()-start:.2f} seconds")